# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""the module is used to register mindspore builtin data augment apis."""

import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.vision.py_transforms as PY

from utils.class_factory import ClassFactory, ModuleType


@ClassFactory.register(ModuleType.PIPELINE)
class Decode:
    """Pipeline step that decodes ``results['image']`` with MindSpore.

    Args:
        decode_mode (str): ``'C'`` picks ``c_transforms.Decode``; any other
            value picks ``py_transforms.Decode``. Defaults to ``'C'``.
    """

    def __init__(self, decode_mode='C'):
        self.decode_mode = decode_mode
        use_c_backend = decode_mode == 'C'
        self.decode = C.Decode() if use_c_backend else PY.Decode()

    def __call__(self, results):
        """Decode the image in place and record its size as (height, width).

        Args:
            results (dict): must contain an ``'image'`` entry.

        Returns:
            dict: the same ``results`` object, with ``'image'`` replaced by
            the decoded image and ``'image_shape'`` set to (H, W).
        """
        decoded = self.decode(results['image'])
        results['image'] = decoded

        if self.decode_mode != 'C':
            # The Python decoder's output reports ``.size`` as (width, height)
            # (PIL convention), so swap it to get (H, W).
            size_wh = decoded.size
            results['image_shape'] = (size_wh[1], size_wh[0])
        else:
            # The C decoder's output exposes ``.shape`` whose leading two
            # entries are (H, W).
            results['image_shape'] = decoded.shape[:2]
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class HWC2CHW:
    """Pipeline step that reorders ``results['image']`` from HWC to CHW layout."""

    def __init__(self):
        self.hwc2chw = C.HWC2CHW()

    def __call__(self, results):
        """Apply MindSpore's ``c_transforms.HWC2CHW`` to ``results['image']``.

        Args:
            results (dict): must contain an ``'image'`` entry.

        Returns:
            dict: the same ``results`` object with ``'image'`` transposed.
        """
        results['image'] = self.hwc2chw(results['image'])
        return results
